package com.rceeslca.common.xss;

import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Request wrapper that runs request parameters, single-value headers and JSON
 * request bodies through {@code XssUtils.filter} before the application sees them.
 *
 * <p>NOTE(review): {@code getReader()} and {@code getHeaders(String)} are not
 * overridden here, so data read through those methods is returned unfiltered.
 *
 * @author rceeslca
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** The request passed to the constructor; exposed through {@link #getOrgRequest()}. */
    final HttpServletRequest orgRequest;

    /**
     * Body bytes of a JSON request after filtering. Filled on the first
     * {@link #getInputStream()} call so that later calls can be served from memory
     * instead of re-reading the already consumed stream of the wrapped request.
     */
    private byte[] filteredBody;


    /**
     * Wraps the given request.
     *
     * @param request the request to wrap
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        orgRequest = request;
    }


    /**
     * Returns the request body. For requests whose {@code Content-Type} starts with
     * {@code application/json} the body is read completely as UTF-8, filtered (unless
     * it is blank) and served from an in-memory buffer; every call returns a fresh
     * stream positioned at the start of that buffer. For all other content types the
     * stream of the wrapped request is returned untouched.
     *
     * @return stream over the (possibly filtered) request body
     * @throws IOException if the wrapped request fails to provide its input stream
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        // Not a JSON request: hand back the wrapped stream unchanged.
        if (!checkContentTypeIsJson()) {
            return super.getInputStream();
        }

        if (filteredBody == null) {
            String raw = IoUtil.readUtf8(super.getInputStream());
            // A blank body is kept as-is. It still has to be served from the buffer,
            // because the wrapped stream has been consumed by the read above.
            String body = StrUtil.isBlank(raw) ? raw : xssEncode(raw);
            filteredBody = body == null ? new byte[0] : body.getBytes(StandardCharsets.UTF_8);
        }
        return new BufferedBodyInputStream(filteredBody);
    }


    /**
     * Returns the filtered value of a request parameter.
     *
     * <p>The parameter name itself is filtered before the lookup, so a name that the
     * filter alters will not match the parameter as it was sent.
     *
     * @param name parameter name
     * @return the filtered value; the unmodified value when it is {@code null} or blank
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(xssEncode(name));
        if (StrUtil.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }


    /**
     * Returns the filtered values of a request parameter. Unlike
     * {@link #getParameter(String)}, the name is used for the lookup as given.
     *
     * @param name parameter name
     * @return a new array holding the filtered values, or {@code null} when the wrapped
     *         request has no values (or an empty array) for {@code name}
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] parameters = super.getParameterValues(name);
        if (parameters == null || parameters.length == 0) {
            return null;
        }
        return encodeValues(parameters);
    }


    /**
     * Returns all request parameters with their values filtered. Keys are copied
     * unchanged and keep the iteration order of the wrapped request's map.
     *
     * @return a new, mutable map from parameter name to a new array of filtered values
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> filtered = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
            String[] values = entry.getValue();
            filtered.put(entry.getKey(), values == null ? null : encodeValues(values));
        }
        return filtered;
    }


    /**
     * Returns the filtered value of a request header.
     *
     * <p>The header name itself is filtered before the lookup, so a name that the
     * filter alters will not match the header as it was sent.
     *
     * @param name header name
     * @return the filtered value; the unmodified value when it is {@code null} or blank
     */
    @Override
    public String getHeader(String name) {
        String value = super.getHeader(xssEncode(name));
        if (StrUtil.isNotBlank(value)) {
            value = xssEncode(value);
        }
        return value;
    }


    /** Applies the XSS filter to a single string by delegating to {@code XssUtils.filter}. */
    private String xssEncode(String input) {
        return XssUtils.filter(input);
    }


    /**
     * Filters every non-blank element of {@code values} into a new array. The input
     * array is never written to: it comes from the wrapped request, and whether that
     * request hands out the same array again on later calls is implementation-specific.
     * {@code null} and blank elements are copied unchanged, matching
     * {@link #getParameter(String)}.
     */
    private String[] encodeValues(String[] values) {
        String[] encoded = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            String value = values[i];
            encoded[i] = StrUtil.isNotBlank(value) ? xssEncode(value) : value;
        }
        return encoded;
    }


    /**
     * Returns the request this wrapper was constructed with.
     *
     * @return the wrapped request, giving access to unfiltered data
     */
    public HttpServletRequest getOrgRequest() {
        return orgRequest;
    }


    /**
     * Unwraps a request if it is an {@code XssHttpServletRequestWrapper}.
     *
     * @param request any request
     * @return the request the wrapper was constructed with, or {@code request} itself
     *         when it is not an instance of this class
     */
    public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
        if (request instanceof XssHttpServletRequestWrapper) {
            return ((XssHttpServletRequestWrapper) request).getOrgRequest();
        }
        return request;
    }


    /**
     * Tells whether this is a JSON request, decided by a case-insensitive prefix match
     * of the unfiltered {@code Content-Type} header against {@code application/json}.
     */
    private boolean checkContentTypeIsJson() {
        String header = super.getHeader(HttpHeaders.CONTENT_TYPE);
        return StrUtil.startWithIgnoreCase(header, MediaType.APPLICATION_JSON_VALUE);
    }


    /**
     * {@link ServletInputStream} over an in-memory copy of the request body.
     */
    private static final class BufferedBodyInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        BufferedBodyInputStream(byte[] body) {
            this.delegate = new ByteArrayInputStream(body);
        }

        /** Finished only once every buffered byte has been read. */
        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        /** The data is already in memory, so a read never blocks. */
        @Override
        public boolean isReady() {
            return true;
        }

        /**
         * Intentionally a no-op: the listener is neither stored nor notified.
         * NOTE(review): non-blocking reads through a ReadListener are therefore not
         * supported by this stream.
         */
        @Override
        public void setReadListener(ReadListener readListener) {
        }

        @Override
        public int read() {
            return delegate.read();
        }

        /** Bulk read, so consumers are not forced through the single-byte path. */
        @Override
        public int read(byte[] b, int off, int len) {
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }
    }

}
